"""
Model definition based on the NanoGPT speedrun: https://github.com/KellerJordan/modded-nanogpt

Corresponds to the Aug 23 record; the next record (Sept 3) requires unmerged PRs in the flash-attention repo.

The unneeded parts of the script are commented out rather than deleted, to preserve the overall file structure for easier updating.
"""

import copy
import glob
import os
import sys
import time

# with open(sys.argv[0]) as f:
#     code = f.read() # read the code of this file ASAP, for logging
import uuid
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist
import torch.nn.functional as F

# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
import triton
import triton.language as tl

# torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
from torch import nn, Tensor

# use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import BlockMask, flex_attention

# -----------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng


@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(
    x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float
) -> tuple[Tensor, Tensor, Tensor]:
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
        out = torch._scaled_mm(
            x_f8,
            w_f8.T,
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(x_s, dtype=torch.float32),
            scale_b=x.new_tensor(w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8

    return impl(x, w)


@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    assert x.ndim == w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)


@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(
    g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float
) -> tuple[Tensor, Tensor]:
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
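        # despite the *_inv_s names, these hold the original scale factors;
        # _scaled_mm multiplies them back in to undo the earlier division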
        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
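        # the .T.contiguous().T below converts the second operand to the
        # column-major layout that torch._scaled_mm requires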
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.T.contiguous().T,
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        grad_w = torch._scaled_mm(
            x_f8.T.contiguous(),
            grad_f8.T.contiguous().T,
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).T
        return grad_x, grad_w

    return impl(g, x_f8, w_f8)


@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)


def backward(ctx, grad_out: Tensor, *_):
    x_f8, w_f8 = ctx.saved_tensors
    x_s, w_s, grad_s = ctx.scales
    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
        grad_out, x_f8, w_f8, x_s, w_s, grad_s
    )
    return grad_x, grad_w, None, None, None


def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    *_, x_s, w_s, grad_s = inputs
    _, x_f8, w_f8 = output
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = x_s, w_s, grad_s
    ctx.set_materialize_grads(False)


mm_op.register_autograd(backward, setup_context=setup_context)
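
# Usage sketch (hypothetical shapes and unit scales; not part of the speedrun script).
# The op behaves like a bias-free linear forward, returning (out, x_f8, w_f8); the
# registered autograd then produces grads for x and w. Requires an FP8-capable GPU.
#
# x = torch.randn(128, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# w = torch.randn(50304, 768, device="cuda", requires_grad=True)  # fp32 master weight
# out, _, _ = torch.ops.nanogpt.mm(x, w, x_s=1.0, w_s=1.0, grad_s=1.0)
# out.sum().backward()
# assert x.grad.shape == x.shape and w.grad.shape == w.shape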

# -----------------------------------------------------------------------------
# Triton kernel for symmetric matrix multiplication by @byronxu99


def _get_autotune_configs():
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": bm,
                "BLOCK_SIZE_N": bn,
                "BLOCK_SIZE_K": bk,
                "GROUP_SIZE_M": 8,
                "LOWER_UPPER": 1,
            },
            num_stages=stages,
            num_warps=warps,
        )
        for bm in [64, 128]
        for bn in [64, 128, 256]
        for bk in [64, 128]
        for stages, warps in [(3, 4), (3, 8), (4, 4)]
        if bm // bn <= 2 and bn // bm <= 2
    ]


@triton.jit
def _pid_to_block(
    pid,
    M,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)

    # Map PID to a single matrix in batch
    batch_idx = pid // (num_pid_m * num_pid_n)
    pid = pid % (num_pid_m * num_pid_n)

    # Map PID to 2D grid of blocks
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    m_idx = pid_m * BLOCK_SIZE_M
    n_idx = pid_n * BLOCK_SIZE_N
    return batch_idx, m_idx, n_idx


@triton.autotune(
    configs=_get_autotune_configs(),
    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
)
@triton.jit
def ns_line_1_kernel(
    A_ptr,
    C_ptr,
    M,
    K,
    a_stride_b,
    a_stride_r,
    a_stride_c,
    c_stride_b,
    c_stride_r,
    c_stride_c,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    LOWER_UPPER: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    batch_idx, m_idx, n_idx = _pid_to_block(
        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )

    # Skip blocks that don't need to be computed
    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
    if skip_block_below_diag or skip_block_above_diag:
        return

    # Index into one matrix of batch
    A_ptr += batch_idx * a_stride_b
    C_ptr += batch_idx * c_stride_b

    # Create pointer arrays for A and A.T
    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Accumulate over blocks of K
    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, at, accumulator)
        a_ptrs += BLOCK_SIZE_K * a_stride_c
        at_ptrs += BLOCK_SIZE_K * a_stride_c

    out_dtype = C_ptr.dtype.element_ty
    output = accumulator.to(out_dtype)

    # Store block of C
    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, output, mask=c_mask)

    # Store block of C mirrored across the diagonal
    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
    tl.store(c_ptrs_t, output.T, mask=c_mask_t)


def ns_line_1(A: torch.Tensor, out: torch.Tensor):
    """
    Launch Triton kernel to compute C = A @ A.T
    """
    assert A.ndim == 2 or A.ndim == 3
    M, K = A.shape[-2:]
    assert out.size(-2) == M, "Output matrix has incorrect shape"
    assert out.size(-1) == M, "Output matrix has incorrect shape"

    batch_size = A.size(0) if A.ndim == 3 else 1
    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
    output_batch_stride = out.stride(0) if out.ndim == 3 else 0

    grid = lambda meta: (
        batch_size
        * triton.cdiv(M, meta["BLOCK_SIZE_M"])
        * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
    )
    ns_line_1_kernel[grid](
        A_ptr=A,
        C_ptr=out,
        M=M,
        K=K,
        a_stride_b=input_batch_stride,
        a_stride_r=A.stride(-2),
        a_stride_c=A.stride(-1),
        c_stride_b=output_batch_stride,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
    )
    return out
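
# Sanity-check sketch (hypothetical sizes and tolerances): the kernel should match
# a plain matmul up to bf16 rounding.
#
# A = torch.randn(6, 768, 3072, device="cuda", dtype=torch.bfloat16)
# C = torch.empty(6, 768, 768, device="cuda", dtype=torch.bfloat16)
# torch.testing.assert_close(ns_line_1(A, out=C), A @ A.mT, rtol=1.6e-2, atol=1e-2)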


@triton.autotune(
    configs=_get_autotune_configs(),
    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
)
@triton.jit
def ns_line_2_kernel(
    A_ptr,
    C_ptr,
    M,
    a_stride_b,
    a_stride_r,
    a_stride_c,
    c_stride_b,
    c_stride_r,
    c_stride_c,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    LOWER_UPPER: tl.constexpr,
):
    # This is mostly duplicated from ns_line_1_kernel, but additionally loads and
    # adds a block of A. Folding the add into ns_line_1_kernel makes that kernel
    # slightly slower, so we keep two separate kernels.
    pid = tl.program_id(axis=0)
    batch_idx, m_idx, n_idx = _pid_to_block(
        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )

    # Skip blocks that don't need to be computed
    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
    if skip_block_below_diag or skip_block_above_diag:
        return

    # Index into one matrix of batch
    A_ptr += batch_idx * a_stride_b
    C_ptr += batch_idx * c_stride_b

    # Create pointer arrays for A and A.T
    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Accumulate over blocks of K
    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, at, accumulator)
        a_ptrs += BLOCK_SIZE_K * a_stride_c
        at_ptrs += BLOCK_SIZE_K * a_stride_c

    # Load block of A to add (corresponds to the current block of C)
    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)

    # Apply alpha and beta
    accumulator *= alpha
    accumulator += a_add * beta

    out_dtype = C_ptr.dtype.element_ty
    output = accumulator.to(out_dtype)

    # Store block of C
    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, output, mask=c_mask)

    # Store block of C mirrored across the diagonal
    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
    tl.store(c_ptrs_t, output.T, mask=c_mask_t)


def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor):
    """
    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A
    """
    assert A.ndim == 2 or A.ndim == 3
    M, K = A.shape[-2:]
    assert M == K, "Input matrix must be square"
    assert out.size(-2) == M
    assert out.size(-1) == M

    batch_size = A.size(0) if A.ndim == 3 else 1
    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
    output_batch_stride = out.stride(0) if out.ndim == 3 else 0

    grid = lambda meta: (
        batch_size
        * triton.cdiv(M, meta["BLOCK_SIZE_M"])
        * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
    )
    ns_line_2_kernel[grid](
        A_ptr=A,
        C_ptr=out,
        M=M,
        a_stride_b=input_batch_stride,
        a_stride_r=A.stride(-2),
        a_stride_c=A.stride(-1),
        c_stride_b=output_batch_stride,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
        alpha=alpha,
        beta=beta,
    )
    return out


# Must use dynamic=False, or else it's much slower
@torch.compile(dynamic=False, fullgraph=True)
def newton_schulz_triton(G: torch.Tensor):
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # Allocate buffers
    X = X.contiguous()
    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
    B = torch.empty_like(A)
    C = torch.empty_like(X)

    ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm

    # Perform the NS iterations
    for _ in range(5):
        ns_line_1(X, out=A)  # A = X @ X.mT
        ns_line_2(A, alpha=c, beta=b, out=B)  # B = b * A + c * A @ A
        ns_line_3(X, B, X, beta=a, out=C)  # C = a * X + B @ X
        X, C = C, X  # Swap references to avoid unnecessary copies

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
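
# Sanity-check sketch: with these quintic coefficients the iteration orthogonalizes
# only approximately (singular values land near 1, not exactly at 1), so expect a
# loose tolerance.
#
# G = torch.randn(768, 768, device="cuda")
# X = newton_schulz_triton(G)
# I = torch.eye(768, device="cuda", dtype=X.dtype)
# print((X @ X.mT - I).abs().max())  # small, but far from machine precision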


# -----------------------------------------------------------------------------
# Muon optimizer

# class Muon(torch.optim.Optimizer):
#     """
#     Muon - MomentUm Orthogonalized by Newton-schulz

#     https://kellerjordan.github.io/posts/muon/

#     Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
#     processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
#     matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
#     the advantage that it can be stably run in bfloat16 on the GPU.

#     Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
#     or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
#     """
#     def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
#         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
#         params = list(params)
#         sizes = {p.shape for p in params}
#         # create one buffer per unique parameter-size
#         param_groups = []
#         for size in sizes:
#             group_params = [p for p in params if p.shape == size]
#             param_groups.append(dict(params=group_params))
#         super().__init__(param_groups, defaults)

#     @torch.no_grad()
#     def step(self):
#         # Efficient systems-wise implementation of step developed by @YouJiacheng,
#         # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
#         # @ryanyang0, and @vagrawal.
#         rank = dist.get_rank()
#         world_size = dist.get_world_size()
#         reduce_scatter_futures: list[torch.Future] = []
#         all_gather_futures: list[torch.Future] = []
#         for group in self.param_groups:
#             params: list[Tensor] = group["params"]
#             grad = torch.empty_like(params[-1])
#             grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size
#             for base_i in range(0, len(params), world_size):
#                 if base_i + rank < len(params):
#                     grad = params[base_i + rank].grad
#                 # This gives strange dynamo warnings
#                 reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future())

#         idx = 0
#         for group in self.param_groups:
#             params: list[Tensor] = group["params"]
#             params_pad = params + [torch.empty_like(params[-1])] * world_size
#             momentum = group["momentum"]
#             for base_i in range(0, len(params), world_size):
#                 reduce_scatter_futures[idx].wait()
#                 if base_i + rank < len(params):
#                     p = params[base_i + rank]
#                     grad = p.grad
#                     eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0)
#                     eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0)
#                     state = self.state[p]
#                     if len(state) == 0:
#                         state["momentum_buffer"] = torch.zeros_like(grad)
#                     momentum_buffer = state["momentum_buffer"]
#                     p.mul_(1 - eff_weight_decay)
#                     momentum_buffer.lerp_(grad, 1 - momentum)
#                     grad = grad.lerp_(momentum_buffer, momentum)
#                     v = newton_schulz_triton(grad)
#                     p.add_(other=v, alpha=-eff_lr)
#                 idx += 1
#                 all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future())
#         torch.futures.collect_all(all_gather_futures).wait()

# class DistAdam(torch.optim.Optimizer):
#     def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
#         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
#         params = list(params)
#         sizes = {p.shape for p in params}
#         # create one buffer per unique parameter-size
#         param_groups = []
#         for size in sizes:
#             group_params = [p for p in params if p.shape == size]
#             param_groups.append(dict(params=group_params))
#         super().__init__(param_groups, defaults)
#         # DistributedAdam implementation by @vagrawal

#     @torch.compile
#     @torch.no_grad()
#     def step(self):
#         rank = dist.get_rank()
#         world_size = dist.get_world_size()
#         reduce_scatter_futures: list[torch.Future] = []
#         all_gather_futures: list[torch.Future] = []
#         grad_slices = []
#         for group in self.param_groups:
#             params: list[Tensor] = group["params"]
#             for base_i in range(len(params)):
#                 grad = params[base_i].grad
#                 rank_size = grad.shape[0] // world_size
#                 grad_slice = torch.empty_like(grad[:rank_size])
#                 reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
#                 grad_slices.append(grad_slice)

#         idx = 0
#         for group in self.param_groups:
#             beta1, beta2 = group['betas']
#             eps = group['eps']
#             wd = group['weight_decay']
#             params = group['params']
#             for base in range(len(params)):
#                 reduce_scatter_futures[idx].wait()
#                 p = params[base]
#                 rank_size = p.shape[0] // world_size
#                 p_slice = p[rank * rank_size:(rank + 1) * rank_size]
#                 lr = group['lr'] * getattr(p, "lr_mul", 1.0)
#                 state = self.state[p]
#                 g_slice = grad_slices[idx]
#                 # State init
#                 if not state:
#                     state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
#                     state['exp_avg'] = torch.zeros_like(p_slice)
#                     state['exp_avg_sq'] = torch.zeros_like(p_slice)
#                 exp_avg = state['exp_avg']
#                 exp_avg_sq = state['exp_avg_sq']
#                 state['step'] += 1
#                 t = state['step']
#                 # weight decay
#                 if wd != 0:
#                     eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
#                     p_slice.mul_(1 - eff_weight_decay)
#                 # update running averages
#                 exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
#                 exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
#                 # bias corrections
#                 bias1 = 1 - beta1 ** t
#                 bias2 = 1 - beta2 ** t
#                 # compute step
#                 denom = exp_avg_sq.sqrt().add_(eps)
#                 step_size = lr * (torch.sqrt(bias2) / bias1)
#                 update = exp_avg.div(denom).mul_(step_size)
#                 p_slice.add_(other=update, alpha=-1.0)
#                 idx += 1
#                 all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
#         torch.futures.collect_all(all_gather_futures).wait()

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model


def norm(x: Tensor):
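    # parameter-free RMS norm over the last dimension (no learnable weight)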
    return F.rms_norm(x, (x.size(-1),))


class CastedLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        use_fp8=False,
        x_s=1.0,
        w_s=1.0,
        grad_s=1.0,
    ):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        # 0.5 is a bit better than the default coefficient of 1/sqrt(3)
        std = 0.5 * (self.in_features**-0.5)
        bound = (3**0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            _x = x.flatten(0, -2)
            out: Tensor = torch.ops.nanogpt.mm(
                _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s
            )[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            return F.linear(x, self.weight.type_as(x))


class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(
            0, 1, steps=dim // 4, dtype=torch.float32
        )
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
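        # the zero-frequency half leaves those dimensions unrotated (cos=1, sin=0)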
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = (
            self.cos[None, : x_BTHD.size(-3), None, :],
            self.sin[None, : x_BTHD.size(-3), None, :],
        )
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)


class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        assert hdim == dim, "num_heads * head_dim must equal model_dim"
        std = 0.5 * (dim**-0.5)
        bound = (3**0.5) * std  # improved init scale by @YouJiacheng
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
        with torch.no_grad():
            self.qkvo_w[:3].uniform_(-bound, bound)  # init QKV weights
            self.qkvo_w[3].zero_()  # init output weights to zero
        self.rotary = Rotary(head_dim, max_seq_len)
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.12

        # sparse gated attention to enable context based no-op by @classiclarryd
        self.attn_gate_dim = 12
        self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
        self.attn_gate.weight.detach().zero_()

    def forward(
        self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask
    ):
        B, T = x.size(0), x.size(1)  # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"

        q, k, v = (
            F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x))
            .view(B, T, 3 * self.num_heads, self.head_dim)
            .chunk(3, dim=-2)
        )
        q, k = norm(q), norm(k)  # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        if ve is not None:
            # @KoszarskyB & @Grad62304977
            v = lambdas[0] * v + lambdas[1] * ve.view_as(v)
        else:  # skip mid-layers token value embeddings by @YouJiacheng
            v = lambdas[0] * v
        y = flex_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            block_mask=block_mask,
            scale=self.attn_scale,
        ).transpose(1, 2)
        y = y.view(B, T, self.num_heads, self.head_dim)
        y = y * torch.sigmoid(self.attn_gate(x[..., : self.attn_gate_dim])).view(
            B, T, self.num_heads, 1
        )
        y = y.contiguous().view(
            B, T, self.num_heads * self.head_dim
        )  # re-assemble all head outputs side by side
        y = F.linear(y, self.qkvo_w[3].type_as(y))
        return y


class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        hdim = 4 * dim
        # make both matrices have the same shape because optimizer sorts params by shape
        # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
        std = 0.5 * (dim**-0.5)
        bound = (3**0.5) * std  # improved init scale by @YouJiacheng
        with torch.no_grad():
            self.c_fc.uniform_(-bound, bound)
            self.c_proj.zero_()  # zero init suggested by @Grad62304977

    def forward(self, x: Tensor):
        x = F.linear(x, self.c_fc.T.type_as(x))
        # ReLU^2 (https://arxiv.org/abs/2109.08668v2): ~1-2% better than GELU;
        # suggested by @SKYLINEZ007 and @Grad62304977
        x = F.relu(x).square()
        x = F.linear(x, self.c_proj.type_as(x))
        return x


class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = (
            CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
        )
        self.mlp = MLP(dim)

    def forward(
        self,
        x: Tensor,
        ve: Tensor | None,
        x0: Tensor,
        lambdas: Tensor,
        sa_lambdas: Tensor,
        block_mask: BlockMask,
    ):
        x = lambdas[0] * x + lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, sa_lambdas, block_mask)
        x = x + self.mlp(norm(x))
        return x


# -----------------------------------------------------------------------------
# The main model


def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
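
# e.g. next_multiple_of_n(50257, n=128) == 50304, used below to pad the vocab size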


class GPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        num_heads: int,
        model_dim: int,
        max_seq_len: int,
    ):
        super().__init__()
        vocab_size = next_multiple_of_n(vocab_size, n=128)
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList(
            [nn.Embedding(vocab_size, model_dim) for _ in range(3)]
        )
        self.blocks = nn.ModuleList(
            [Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]
        )
        # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        use_fp8 = not os.environ.get("DISABLE_FP8", False)
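        # FP8 scales for the head: 448 is the max representable value of float8_e4m3fn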
        self.lm_head = CastedLinear(
            model_dim,
            vocab_size,
            use_fp8=use_fp8,
            x_s=(model_dim**0.5) / 448,
            w_s=2**-9,
            grad_s=1 / 448,
        )
        self.lm_head.weight.detach().zero_()  # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        pad = (-num_layers * 5) % dist.get_world_size()
        self.scalars = nn.Parameter(
            torch.cat(
                [
                    torch.ones(num_layers),  # skip_weights
                    *[
                        torch.tensor([1.0, 0.0]) for _ in range(num_layers)
                    ],  # block lambdas
                    *[
                        torch.tensor([0.5, 0.5]) for _ in range(num_layers)
                    ],  # SA lambdas
                    torch.ones(pad),
                ]
            )
        )
        # set learning rates
        for param in self.embed.parameters():
            param.lr_mul = 75.0
        for param in self.value_embeds.parameters():
            param.lr_mul = 75.0
        self.lm_head.weight.lr_mul = 1.0
        self.scalars.lr_mul = 5.0

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 128
        docs = (input_seq == 50256).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = (
                dense_blockmask.argsort(dim=-1, descending=False, stable=True)
                .flip(-1)
                .to(torch.int32)
            )
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (
            docs_high[:, None] >= docs_low
        )
        document_blockmask_all = (docs_low[:, None] == docs_high) & (
            docs_high[:, None] == docs_low
        )
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(
            blockmask_any & ~blockmask_all
        )
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)

        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(
                    partial_kv_num_blocks,
                    torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1),
                ),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )

        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following the Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(
            sliding_window_num_blocks // 2
        )

    def forward(
        self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor
    ):
        assert input_seq.ndim == 1

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = (
            [ve[0], ve[1], ve[2]]
            + [None] * (len(self.blocks) - 6)
            + [ve[0], ve[1], ve[2]]
        )
        assert len(ve) == len(self.blocks)

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [
            long_bm,
            short_bm,
            short_bm,
            short_bm,
            long_bm,
            short_bm,
            short_bm,
            long_bm,
            short_bm,
            short_bm,
            short_bm,
            long_bm,
        ]
        assert len(block_masks) == len(self.blocks)

        x = x0 = norm(self.embed(input_seq)[None])  # use of norm here by @Grad62304977

        # U-net design by @brendanh0gan
        skip_connections = []
        skip_weights = self.scalars[: (len(self.blocks) // 2)]
        lambdas = self.scalars[1 * len(self.blocks) : 3 * len(self.blocks)].view(-1, 2)
        sa_lambdas = self.scalars[3 * len(self.blocks) : 5 * len(self.blocks)].view(
            -1, 2
        )

        n = len(self.blocks) // 2

        for i in range(len(self.blocks)):
            if i >= n:
                x = x + skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
            if i < n:
                skip_connections.append(x)

        x = norm(x)
        logits = self.lm_head(x).float()
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / 7.5)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            target_seq,
            reduction="sum" if self.training else "mean",
        )
        return loss


# -----------------------------------------------------------------------------
# Distributed data loader

# def _load_data_shard(file: Path):
#     header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
#     assert header[0] == 20240520, "magic number mismatch in the data .bin file"
#     assert header[1] == 1, "unsupported version"
#     num_tokens = int(header[2]) # number of tokens (claimed)
#     with file.open("rb", buffering=0) as f:
#         tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
#         f.seek(256 * 4)
#         nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
#         assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
#     return tokens

# # find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap, by @classiclarryd
# def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
#     boundary_mask = tokens[pos : pos + token_window] == 50256
#     boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
#     start = boundary_positions[0].item()
#     starts = []
#     for i in range(1, len(boundary_positions)):
#         end = boundary_positions[i].item()
#         if end - start >= seq_len:
#             starts.append(start) # append start once end pos is confirmed
#             if len(starts) == dist.get_world_size():
#                 return starts, end - pos
#             start = end
#     assert False # increase token_window if necessary

# def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
#     rank = dist.get_rank()
#     world_size = dist.get_world_size()
#     batch_size = seq_len * world_size
#     files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
#     file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
#     tokens, pos = _load_data_shard(next(file_iter)), 0
#     while True:
#         token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len
#         if pos + token_window + 1 >= len(tokens):
#             tokens = _load_data_shard(next(file_iter))
#             pos = 0
#         for _ in range(grad_accum_steps):
#             if align_to_bos:
#                 batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window)
#                 start_idx = batch_starts[rank]
#             else:
#                 tokens_consumed = batch_size
#                 start_idx = pos + rank * seq_len
#             buf = tokens[start_idx:][:seq_len + 1]
#             inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
#             targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
#             pos += tokens_consumed
#             token_window -= tokens_consumed
#             yield inputs, targets

# # -----------------------------------------------------------------------------
# # int main

# @dataclass
# class Hyperparameters:
#     # data
#     train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
#     val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
#     val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
#     train_seq_len = 48*1024 # FlexAttention sequence length
#     val_seq_len = 4*64*1024 # FlexAttention sequence length for validation
#     # optimization
#     num_iterations = 1695 # number of iterations to run
#     cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate
#     # evaluation and logging
#     run_id = uuid.uuid4()
#     val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end
#     save_checkpoint = False
# args = Hyperparameters()

# data_path = os.environ.get("DATA_PATH", ".")
# args.train_files = os.path.join(data_path, args.train_files)
# args.val_files = os.path.join(data_path, args.val_files)

# # torchrun sets these env variables
# rank = int(os.environ["RANK"])
# world_size = int(os.environ["WORLD_SIZE"])
# assert 8 % world_size == 0, "world_size must be a divisor of 8"
# grad_accum_steps = 8 // world_size
# assert torch.cuda.is_available()
# device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
# torch.cuda.set_device(device)
# dist.init_process_group(backend="nccl", device_id=device)
# dist.barrier()
# master_process = (rank == 0) # this process will do logging, checkpointing etc.

# # begin logging
# logfile = None
# if master_process:
#     run_id = args.run_id
#     os.makedirs("logs", exist_ok=True)
#     logfile = f"logs/{run_id}.txt"
#     print(logfile)
# def print0(s, console=False):
#     if master_process:
#         with open(logfile, "a") as f:
#             if console:
#                 print(s)
#             print(s, file=f)

# # begin by printing this file (the Python code)
# print0(code)
# print0("="*100)
# # log information about the hardware/software environment this is running on
# print0(f"Running Python {sys.version}")
# print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
# print0(f"Running Triton version {triton.__version__}")
# def nvidia_smi():
#     import subprocess  # avoid top level import
#     return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
# print0(nvidia_smi())
# print0("="*100)

# model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
# for m in model.modules():
#     if isinstance(m, nn.Embedding):
#         m.bfloat16()
# for param in model.parameters():
#     dist.broadcast(param.detach(), 0)

# # collect the parameters to optimize
# hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
# embed_params = [p for n, p in model.named_parameters() if "embed" in n]
# scalar_params = [p for p in model.parameters() if p.ndim < 2]
# head_params = [model.lm_head.weight]

# # init the optimizer(s)
# # small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
# # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
# optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
# optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
# optimizers = [optimizer1, optimizer2]
# for opt in optimizers:
#     for group in opt.param_groups:
#         group["initial_lr"] = group["lr"]

# # learning rate schedule: stable then decay
# def get_lr(step: int):
#     x = step / args.num_iterations # progress in training
#     assert 0 <= x < 1
#     if x < 1 - args.cooldown_frac:
#         return 1.0
#     else:
#         w = (1 - x) / args.cooldown_frac
#         return w * 1.0 + (1 - w) * 0.1


# # attention window size schedule: linearly increase
# @lru_cache(1)
# def get_window_size_blocks_helper(window_size: int):
#     return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# def get_window_size_blocks(step: int):
#     x = step / args.num_iterations # progress in training
#     assert 0 <= x <= 1
#     # Linearly increase the block-wise sliding window size over training 128 -> 1792
#     # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
#     window_size = next_multiple_of_n(1728 * x, n=128)
#     return get_window_size_blocks_helper(window_size)
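# Simplified stand-in for the schedule above: sliding window size in 128-token
# blocks at a given training progress (progress 0.0 -> 128 tokens, 1.0 -> 1792).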
def get_window_size_blocks(progress=0.5, device="cuda"):
    window_size = next_multiple_of_n(1728 * progress, n=128)
    return torch.tensor(window_size // 128, dtype=torch.int32, device=device)
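
# Minimal single-process usage sketch (hypothetical; GPT.__init__ calls
# dist.get_world_size(), so a process group must exist even on one GPU, and the
# FP8 head needs an FP8-capable GPU unless DISABLE_FP8 is set):
#
# os.environ.setdefault("MASTER_ADDR", "localhost")
# os.environ.setdefault("MASTER_PORT", "29500")
# dist.init_process_group(backend="nccl", rank=0, world_size=1)
# model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768,
#             max_seq_len=48 * 1024).cuda()
# seq = torch.randint(0, 50257, (48 * 1024 + 1,), device="cuda")
# loss = model(seq[:-1].to(torch.int32), seq[1:], get_window_size_blocks(0.5))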


# model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)

# ########################################
# #            Warmup kernels            #
# ########################################

# # Warmup the training kernels, then re-initialize the state so we aren't cheating
# warmup_steps = 10
# initial_state = dict(model=copy.deepcopy(model.state_dict()),
#                      optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
# train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
# for _ in range(warmup_steps):
#     inputs, targets = next(train_loader)
#     model(inputs, targets, get_window_size_blocks(1)).backward()
#     for opt in optimizers:
#         opt.step()
#     model.zero_grad(set_to_none=True)
# model.load_state_dict(initial_state["model"])
# for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
#     opt.load_state_dict(opt_state)
# del train_loader, initial_state

# ########################################
# #        Training and validation       #
# ########################################

# train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
# training_time_ms = 0
# # start the clock
# torch.cuda.synchronize()
# t0 = time.perf_counter()
# # begin training
# train_steps = args.num_iterations
# for step in range(train_steps + 1):
#     last_step = (step == train_steps)

#     # --------------- VALIDATION SECTION -----------------
#     if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
#         # stop the clock
#         torch.cuda.synchronize()
#         training_time_ms += 1000 * (time.perf_counter() - t0)
#         model.eval()
#         val_batch_size = world_size * args.val_seq_len
#         assert args.val_tokens % val_batch_size == 0
#         val_steps = args.val_tokens // val_batch_size
#         val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False)
#         val_loss = 0
#         with torch.no_grad():
#             for _ in range(val_steps):
#                 inputs, targets = next(val_loader)
#                 val_loss += model(inputs, targets, get_window_size_blocks(step))
#         val_loss /= val_steps
#         del val_loader
#         dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
#         print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
#         model.train()
#         # start the clock again
#         torch.cuda.synchronize()
#         t0 = time.perf_counter()

#     if last_step:
#         if master_process and args.save_checkpoint:
#             log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
#             os.makedirs(f"logs/{run_id}", exist_ok=True)
#             torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
#         # the last step only has the validation loop, so break to avoid training
#         break

#     # --------------- TRAINING SECTION -----------------
#     for _ in range(grad_accum_steps):
#         inputs, targets = next(train_loader)
#         model(inputs, targets, get_window_size_blocks(step)).backward()
#     # set optimization hyperparameters
#     for opt in optimizers:
#         for group in opt.param_groups:
#             group["lr"] = group["initial_lr"] * get_lr(step)
#     for group in optimizer2.param_groups:
#         frac = min(step / 300, 1) # momentum warmup for muon
#         group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
#     # step the optimizers
#     for opt in optimizers:
#         opt.step()
#     # null the gradients
#     model.zero_grad(set_to_none=True)
#     # logging
#     approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
#     print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)

# print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
#        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
# dist.destroy_process_group()
